;;; This implementation was originally based, loosely, on the following implementation:
;;;   https://github.com/Another-Ghost/3D-Collision-Detection-and-Resolution-Using-GJK-and-EPA
;;; It has since been substantially modified and extended.

(in-package #:org.shirakumo.fraf.trial.gjk)

(defconstant GJK-ITERATIONS 64
  "Maximum number of iterations of the GJK main simplex loop.")
(defconstant EPA-ITERATIONS 64
  "Maximum number of polytope expansion steps EPA performs before giving up.")
(defconstant EPA-TOLERANCE 0.0001
  "EPA stops once the new support point lies less than this far beyond the closest face.")
(defconstant EPA-MAX-FACES 256
  "Capacity of the EPA face buffer; no new faces are added once it is full.")
(defconstant EPA-MAX-LOOSE-EDGES 128
  "Capacity of the EPA loose-edge buffer used while re-triangulating the polytope.")

;;;; GJK main algorithm
;; TODO: avoid consing from v-

(declaim (ftype (function (vec3 vec3 vec3 vec3 &optional vec3) (or null vec3)) barycentric))
(defun barycentric (a b c p &optional (res (vec3)))
  "Compute the barycentric coordinates of P relative to the triangle A B C.

The weights of A, B and C are written to the X, Y and Z components of
RES, which is returned. The dot products are promoted to double-float
before solving, and the resulting weights are converted back to
single-float. If the triangle is degenerate (the determinant
|AB|^2 |AC|^2 - (AB.AC)^2 is exactly zero), RES is left untouched and
NIL is returned."
  (declare (optimize speed (safety 0)))
  (declare (type vec3 a b c p res))
  (flet ((wide (x) (coerce x 'double-float))
         (narrow (x) (coerce x 'single-float)))
    (declare (inline wide narrow))
    (let ((ab (v- b a))
          (ac (v- c a))
          (ap (v- p a)))
      (declare (dynamic-extent ab ac ap))
      (let* ((ab-ab (wide (v. ab ab)))
             (ab-ac (wide (v. ab ac)))
             (ac-ac (wide (v. ac ac)))
             (ap-ab (wide (v. ap ab)))
             (ap-ac (wide (v. ap ac)))
             (det (- (* ab-ab ac-ac) (* ab-ac ab-ac))))
        (if (zerop det)
            NIL
            (let ((weight-b (/ (- (* ac-ac ap-ab) (* ab-ac ap-ac)) det))
                  (weight-c (/ (- (* ab-ab ap-ac) (* ab-ac ap-ab)) det)))
              (vsetf res
                     (narrow (- 1d0 weight-b weight-c))
                     (narrow weight-b)
                     (narrow weight-c))))))))

(declaim (inline plane-normal))
(defun plane-normal (a b c &optional (res (vec3)))
  "Store the normalised cross product (B - A) x (C - A) into RES.

This is the unit normal of the plane through A, B and C, oriented by the
right-hand rule for the winding A -> B -> C. Returns the result of
normalising RES."
  (declare (optimize speed (safety 0)))
  (declare (type vec3 a b c res))
  (let ((edge-ab (v- b a))
        (edge-ac (v- c a)))
    (declare (dynamic-extent edge-ab edge-ac))
    (!vc res edge-ab edge-ac)
    (nvunit* res)))

(declaim (inline point))
(defstruct (point
            (:include vec3)
            (:constructor point (&optional (varr3 (make-array 3 :element-type 'single-float))))
            (:predicate NIL)
            (:copier NIL))
  "A vertex of the Minkowski difference B - A that remembers where it came from.

The included VEC3 holds the difference vector itself, while the A and B
slots hold the support points on the two primitives that produced it
(as filled in by SEARCH-POINT)."
  (a (vec3 0 0 0) :type vec3)
  (b (vec3 0 0 0) :type vec3))

(declaim (inline p<-))
(defun p<- (target src)
  "Copy SRC into TARGET by value: the difference vector and both support points.
Returns TARGET."
  (declare (type point target src))
  (v<- (point-a target) (point-a src))
  (v<- (point-b target) (point-b src))
  (v<- target src)
  target)

(defun search-point (p +dir a b)
  "Fill P with the support point of the Minkowski difference B - A along +DIR.

B's support point in direction +DIR is stored in P's B slot, A's support
point in the opposite direction in P's A slot, and the vector part of P
is set to their difference (B slot minus A slot)."
  (declare (optimize speed))
  (declare (type point p))
  (declare (type vec3 +dir))
  (let ((support-a (point-a p))
        (support-b (point-b p))
        (opposite (v- +dir)))
    (declare (dynamic-extent opposite))
    (trial:global-support-function b +dir support-b)
    (trial:global-support-function a opposite support-a)
    (!v- p support-b support-a)))

(defun %gjk (a b dir s0 s1 s2 s3)
  "Run the GJK intersection test between the primitives A and B.

DIR and S0..S3 are caller-provided scratch POINTs. The simplex is built
from support points of the Minkowski difference B - A (see SEARCH-POINT).

Returns T once the tetrahedron S0..S3 encloses the origin, i.e. the
primitives intersect; S0..S3 can then seed EPA.

Returns NIL as soon as a new support point fails to pass the origin. In
that case S0 holds that support point and DIR the direction it was
searched along, on every such exit. Also returns NIL if GJK-ITERATIONS
is exhausted."
  (declare (type trial:primitive a b))
  (declare (type point dir s0 s1 s2 s3))
  (declare (optimize speed))
  (let ((s12 (point))
        (t1 (vec3)))
    (declare (dynamic-extent s12 t1))
    ;; Initial search direction: from B's location towards A's location.
    ;; S0 is only used as a temporary here.
    (trial::global-location a dir)
    (trial::global-location b s0)
    (nv- dir s0)
    (search-point s2 dir a b)
    ;; Search from the first support point back towards the origin.
    (!v- dir s2)
    (search-point s1 dir a b)
    (when (< (v. s1 dir) 0)
      ;; The new point did not pass the origin, so the shapes are separated.
      ;; S0 still holds B's location at this point. Store the actual support
      ;; point there instead, so that S0/DIR mean the same thing as on the
      ;; NIL exit inside the main loop (callers measure along DIR to S0).
      (p<- s0 s1)
      (return-from %gjk NIL))
    ;; Search perpendicular to the segment S1-S2, towards the origin.
    (!v- s12 s2 s1)
    (!vc dir (!vc dir s12 (!v- t1 s1)) s12)
    (when (v= 0 dir)
      ;; The origin lies on the segment; any perpendicular direction works.
      (!vc dir s12 +vx3+)
      (when (v= 0 dir)
        (!vc dir s12 +vz3+)))
    ;; DIM is the number of points in the simplex before adding S0.
    (loop with dim of-type (unsigned-byte 8) = 2
          for i from 0 below GJK-ITERATIONS
          do (search-point s0 dir a b)
             (when (< (v. s0 dir) 0)
               (return NIL))
             (incf dim)
             (cond ((= 3 dim)
                    ;; Triangle: reduce to the feature nearest the origin.
                    (setf dim (update-simplex s0 s1 s2 s3 dir)))
                   ((null (test-simplex s0 s1 s2 s3 dir))
                    ;; Tetrahedron without the origin: reduced to a triangle.
                    (setf dim 3))
                   (T
                    (return T)))
          ;; LOOP returns NIL after FINALLY unless it RETURNs explicitly.
          finally (trial::dbg "GJK Overflow"))))

(trial:define-distance (trial:primitive trial:primitive)
  ;; Zero when GJK reports an intersection. Otherwise the separation is
  ;; estimated by projecting S0, as left by %GJK, onto the normalised final
  ;; search direction DIR.
  ;; TODO: this could be far more efficient with a more optimised routine
  ;;       See https://dyn4j.org/2010/04/gjk-distance-closest-points/
  (let ((dir (point))
        (s0 (point))
        (s1 (point))
        (s2 (point))
        (s3 (point)))
    (declare (dynamic-extent dir s0 s1 s2 s3))
    (cond ((%gjk a b dir s0 s1 s2 s3)
           0.0)
          (T
           (abs (v. (nvunit dir) s0))))))

(defun detect-hits (a b hits start end &optional (epa T))
  "Test primitives A and B for intersection, recording a contact in HITS.

If END <= START, START is returned immediately. Otherwise the hit object
at index START of HITS is used. If GJK finds an intersection and, when
EPA is true, EPA also produces a contact, TRIAL:FINISH-HIT is called on
that hit and (1+ START) is returned. In every other case START is
returned. With EPA false, this function only sets the hit normal to +Y
and computes no depth or location itself."
  (declare (type trial:primitive a b))
  (declare (type (unsigned-byte 32) start end))
  (declare (type simple-vector hits))
  (declare (optimize speed))
  (if (<= end start)
      start
      (let ((hit (aref hits start))
            (dir (point)) (s0 (point)) (s1 (point)) (s2 (point)) (s3 (point)))
        (declare (dynamic-extent dir s0 s1 s2 s3))
        (cond ((not (%gjk a b dir s0 s1 s2 s3))
               start)
              ((and epa (not (epa s0 s1 s2 s3 a b hit)))
               start)
              (T
               (unless epa
                 (vsetf (trial:hit-normal hit) 0 1 0))
               (trial:finish-hit hit a b)
               (1+ start))))))

(defun update-simplex (s0 s1 s2 s3 dir)
  "Reduce the triangle simplex S0 S1 S2 to the feature nearest the origin.

S0 is the most recently added point. If the origin lies beyond edge
S0-S1 or edge S0-S2, the simplex becomes that edge and 2 is returned.
Otherwise the triangle is kept and shifted into S1..S3, so the next
support point can go into S0, and 3 is returned. The triangle's winding
is flipped when the origin lies on the negative side of its normal. In
all cases DIR receives the next search direction."
  (declare (optimize speed (safety 0)))
  (declare (type point s0 s1 s2 s3))
  (declare (type vec3 dir))
  (let ((ao (v- s0))
        (ab (vec3))
        (ac (vec3))
        (normal (vec3))
        (scratch (vec3)))
    (declare (dynamic-extent ao ab ac normal scratch))
    (!v- ab s1 s0)
    (!v- ac s2 s0)
    (!vc normal ab ac)
    (cond ((plusp (v. ao (!vc scratch ab normal)))
           ;; Origin beyond edge S0-S1: drop S2.
           (p<- s2 s0)
           (!vc dir (!vc dir ab ao) ab)
           2)
          ((plusp (v. ao (!vc scratch normal ac)))
           ;; Origin beyond edge S0-S2: drop S1.
           (p<- s1 s0)
           (!vc dir (!vc dir ac ao) ac)
           2)
          ((plusp (v. normal ao))
           ;; Origin on the positive side of the triangle.
           (p<- s3 s2)
           (p<- s2 s1)
           (p<- s1 s0)
           (v<- dir normal)
           3)
          (T
           ;; Origin on the negative side: swap winding, search along -NORMAL.
           (p<- s3 s1)
           (p<- s1 s0)
           (v<- dir normal)
           (nv- dir)
           3))))

(defun test-simplex (s0 s1 s2 s3 dir)
  "Check whether the tetrahedron S0 S1 S2 S3 encloses the origin.

S0 is the apex that was just added and S1 S2 S3 the base. Returns T if
the origin is not in front of any of the three faces through S0.
Otherwise the simplex is reduced to the first such face (stored in
S1..S3, leaving S0 free), DIR is set to that face's normal, and NIL is
returned."
  (declare (optimize speed (safety 0)))
  (declare (type point s0 s1 s2 s3))
  (declare (type vec3 dir))
  (let ((ao (v- s0))
        (ab (vec3)) (ac (vec3)) (ad (vec3))
        (abc (vec3)) (acd (vec3)) (adb (vec3)))
    (declare (dynamic-extent ao ab ac ad abc acd adb))
    (!v- ab s1 s0)
    (!v- ac s2 s0)
    (!v- ad s3 s0)
    (!vc abc ab ac)
    (!vc acd ac ad)
    (!vc adb ad ab)
    (cond ((plusp (v. abc ao))
           ;; In front of face S0 S1 S2.
           (p<- s3 s2)
           (p<- s2 s1)
           (p<- s1 s0)
           (v<- dir abc)
           NIL)
          ((plusp (v. acd ao))
           ;; In front of face S0 S2 S3.
           (p<- s1 s0)
           (v<- dir acd)
           NIL)
          ((plusp (v. adb ao))
           ;; In front of face S0 S3 S1.
           (p<- s2 s3)
           (p<- s3 s1)
           (p<- s1 s0)
           (v<- dir adb)
           NIL)
          (T
           T))))

;;;; EPA for depth and normal computation
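;;; FIXME: only the FACES and LOOSE-EDGES arrays are declared DYNAMIC-EXTENT; the
;;;        POINT instances filling them are still freshly allocated on every call.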
(defun epa (s0 s1 s2 s3 a b hit)
  "Run the Expanding Polytope Algorithm from the tetrahedron S0 S1 S2 S3.

S0..S3 are the simplex points left by a successful %GJK run on the
primitives A and B. The polytope is repeatedly expanded with the support
point along the normal of its face closest to the origin. Expansion stops
when that support point lies less than EPA-TOLERANCE beyond the face, or
after EPA-ITERATIONS steps. The face closest to the origin on the final
polytope is then handed to %EPA-FINISH, which fills HIT. Its value is
returned (NIL if that face is degenerate)."
  (declare (optimize speed (safety 1)))
  (declare (type point s0 s1 s2 s3))
  (declare (type trial:hit hit))
  (let ((faces (make-array (* 4 EPA-MAX-FACES)))
        (loose-edges (make-array (* 2 EPA-MAX-LOOSE-EDGES)))
        (num-faces 4) (closest-face 0) (min-dist 0.0)
        (search-dir (vec3)) (p (point)))
    (declare (dynamic-extent faces loose-edges search-dir p))
    (declare (type (unsigned-byte 16) num-faces))
    ;; Fill both buffers with distinct POINT instances so that the loop below
    ;; can copy into them with P<- instead of allocating. P ends up fresh.
    (dotimes (i (* 4 EPA-MAX-FACES))
      (setf (aref faces i) p)
      (setq p (point)))
    (dotimes (i (* 2 EPA-MAX-LOOSE-EDGES))
      (setf (aref loose-edges i) p)
      (setq p (point)))
    (macrolet ((v (f v)
                 `(the point (aref faces (+ (* 4 ,f) ,v))))
               (e (e v)
                 `(the point (aref loose-edges (+ (* 2 ,e) ,v)))))
      (flet ((find-closest-face ()
               ;; Set CLOSEST-FACE and MIN-DIST to the current face whose
               ;; plane lies closest to the origin.
               (setf min-dist (v. (v 0 0) (v 0 3)))
               (setf closest-face 0)
               (loop for i from 1 below num-faces
                     for dist = (v. (v i 0) (v i 3))
                     do (when (< dist min-dist)
                          (setf min-dist dist)
                          (setf closest-face i)))))
        (declare (inline find-closest-face))
        ;; Construct the initial polytope
        ;; The FACES array contains the packed vertices and normal of each face
        (p<- (v 0 0) s0)
        (p<- (v 0 1) s1)
        (p<- (v 0 2) s2)
        (plane-normal s0 s1 s2 (v 0 3))
        (p<- (v 1 0) s0)
        (p<- (v 1 1) s2)
        (p<- (v 1 2) s3)
        (plane-normal s0 s2 s3 (v 1 3))
        (p<- (v 2 0) s0)
        (p<- (v 2 1) s3)
        (p<- (v 2 2) s1)
        (plane-normal s0 s3 s1 (v 2 3))
        (p<- (v 3 0) s1)
        (p<- (v 3 1) s3)
        (p<- (v 3 2) s2)
        (plane-normal s1 s3 s2 (v 3 3))
        ;; Main iteration loop to find the involved faces
        (dotimes (i EPA-ITERATIONS
                    ;; Only evaluated when we did NOT converge. The face that
                    ;; was closest at the start of the last iteration faced the
                    ;; new support point and was removed. So CLOSEST-FACE now
                    ;; indexes a different or dead slot and must be recomputed
                    ;; on the final polytope.
                    (find-closest-face))
          ;; Find the closest face in our set of known polytope faces
          (find-closest-face)
          (v<- search-dir (v closest-face 3))
          ;; Find a new direction to search in via the support functions
          (search-point p search-dir a b)
          (when (< (- (v. p search-dir) min-dist) EPA-TOLERANCE)
            (return))
          ;; We still haven't found a face that's good enough, so expand the
          ;; polytope from our current face set
          (let ((num-loose-edges 0)
                (i 0))
            (declare (type (unsigned-byte 16) num-loose-edges i))
            ;; Find triangles facing our current search point.
            ;; SEARCH-DIR is reused as a temporary from here on.
            (loop (when (<= num-faces i) (return))
                  (cond ((< 0 (v. (v i 3) (!v- search-dir p (v i 0))))
                         ;; This face will be removed, so add the edges back
                         (loop for j from 0 below 3
                               for edge-a = (v i j)
                               for edge-b = (v i (mod (1+ j) 3))
                               for edge-found-p = NIL
                               do (dotimes (k num-loose-edges)
                                    ;; A reversed duplicate is shared with an
                                    ;; already removed face: drop both copies.
                                    (when (and (v= (e k 1) edge-a)
                                               (v= (e k 0) edge-b))
                                      (decf num-loose-edges)
                                      (p<- (e k 0) (e num-loose-edges 0))
                                      (p<- (e k 1) (e num-loose-edges 1))
                                      (setf edge-found-p T)
                                      (return)))
                                  ;; The edge isn't already in the loose edge list
                                  ;; so add it back in now.
                                  (unless edge-found-p
                                    (when (<= EPA-MAX-LOOSE-EDGES num-loose-edges)
                                      (trial::dbg "EPA Edges Overflow")
                                      (return))
                                    (p<- (e num-loose-edges 0) edge-a)
                                    (p<- (e num-loose-edges 1) edge-b)
                                    (incf num-loose-edges)))
                         ;; This face is facing our search point, so we remove
                         ;; it by replacing it with the tail face
                         (decf num-faces)
                         (p<- (v i 0) (v num-faces 0))
                         (p<- (v i 1) (v num-faces 1))
                         (p<- (v i 2) (v num-faces 2))
                         (v<- (v i 3) (v num-faces 3)))
                        (T
                         (incf i))))
            ;; Expand the polytope with the search point added to the new loose edge faces
            (dotimes (i num-loose-edges)
              (when (<= EPA-MAX-FACES num-faces)
                (trial::dbg "EPA Faces Overflow")
                (return))
              (p<- (v num-faces 0) (e i 0))
              (p<- (v num-faces 1) (e i 1))
              (p<- (v num-faces 2) p)
              (plane-normal (v num-faces 0) (v num-faces 1) (v num-faces 2) (v num-faces 3))
              ;; Check the CCW winding order via normal test
              (when (< (+ (v. (v num-faces 0) (v num-faces 3)) 0.000001) 0)
                (rotatef (v num-faces 0) (v num-faces 1))
                (nv- (v num-faces 3)))
              (incf num-faces)))
          (when (= (1+ i) EPA-ITERATIONS)
            (trial::dbg "EPA Overflow")))
        ;; Compute the actual intersection from the closest face of the
        ;; polytope, whether or not we converged.
        (%epa-finish hit (v closest-face 0) (v closest-face 1) (v closest-face 2) (v closest-face 3))))))

(defun %epa-finish (hit c0 c1 c2 n)
  "Fill HIT from the EPA face C0 C1 C2 with unit normal N.

The origin is projected onto the face plane, overwriting N, and expressed
in barycentric coordinates of the face. The same weights are applied to
the A and B support points of the vertices to get a contact point on
each primitive. HIT's location is set to the point on A and its depth to
the distance between the two points. Its normal is set to the unit
direction from the A point to the B point, or +Y if they coincide.
Returns NIL without touching HIT if the face is degenerate."
  (declare (type point c0 c1 c2 n))
  (declare (type trial:hit hit))
  (declare (optimize speed (safety 0)))
  (let ((weights (vec3)) (on-a (vec3)) (on-b (vec3)))
    (declare (dynamic-extent weights on-a on-b))
    (when (barycentric c0 c1 c2 (nv* n (v. n c0)) weights)
      (let ((w0 (vx weights))
            (w1 (vy weights))
            (w2 (vz weights)))
        (nv+* on-a (point-a c0) w0)
        (nv+* on-a (point-a c1) w1)
        (nv+* on-a (point-a c2) w2)
        (nv+* on-b (point-b c0) w0)
        (nv+* on-b (point-b c1) w1)
        (nv+* on-b (point-b c2) w2))
      (let ((normal (trial:hit-normal hit)))
        (v<- (trial:hit-location hit) on-a)
        (!v- normal on-b on-a)
        (setf (trial:hit-depth hit) (vlength normal))
        (if (= 0.0 (trial:hit-depth hit))
            (v<- normal +vy3+)
            (nv/ normal (trial:hit-depth hit)))))))
